import torch.nn as nn

class NormalizingFlow(nn.Module):
    """Stack of flow layers on top of a base ``prior`` distribution.

    Each layer is expected to expose two methods:

    * ``f(X, C) -> (X_out, change)``, used in the density direction by
      :meth:`log_prob`. ``change`` is added to the running log-likelihood;
      presumably it is the per-sample log-abs-det Jacobian (TODO confirm).
    * ``g(X, C) -> X_out``, used in the sampling direction by :meth:`sample`.
      It is applied in reverse layer order, so it is presumably the inverse
      of ``f`` (TODO confirm).

    ``prior`` must provide ``log_prob(X)`` and ``sample(shape)``.
    """

    def __init__(self, layers, prior):
        super().__init__()

        # ModuleList registers the layers as submodules so that their
        # parameters are tracked by this module.
        self.layers = nn.ModuleList(layers)
        self.prior = prior

    def log_prob(self, X, C):
        """Return the log-likelihood of ``X`` under the flow, given context ``C``.

        ``X`` is pushed through every layer's ``f`` in order. The ``change``
        terms from all layers are summed, and the prior log-density of the
        final transformed ``X`` is added.

        With no layers, this reduces to ``self.prior.log_prob(X)``.
        """
        # Start from 0 rather than None so that an empty flow does not
        # try to add None to the prior term.
        log_likelihood = 0
        for layer in self.layers:
            X, change = layer.f(X, C)
            log_likelihood = log_likelihood + change
        return log_likelihood + self.prior.log_prob(X)

    def sample(self, C):
        """Draw samples from the flow.

        ``C`` is either a plain ``int``, the number of unconditional samples
        (the layers then receive ``C=None``), or a sized context object.
        In the second case, one sample is drawn per element, i.e. ``len(C)``
        samples.

        Prior samples are pushed through each layer's ``g`` in reverse order.
        """
        # bool is a subclass of int; exclude it so sample(True) is not
        # silently treated as "draw one sample".
        if isinstance(C, int) and not isinstance(C, bool):
            n, C = C, None
        else:
            n = len(C)

        X = self.prior.sample((n,))
        # Iterate the ModuleList backwards without slicing it into a new
        # ModuleList.
        for layer in reversed(self.layers):
            X = layer.g(X, C)
        return X
